Enable fp16 channelwise convolution #4808

Open
klin2024 wants to merge 4 commits into develop from enable_fp16_channelwise

Conversation

@klin2024
Contributor

Motivation

Extend the channelwise convolution prefuse pass to match half_type in addition to float_type.

Technical Details

Allow the channelwise convolution prefuse pass to recognize half_type.

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

Allow channelwise convolution prefuse pass to match half_type
in addition to float_type.
@klin2024 klin2024 marked this pull request as ready for review April 21, 2026 01:35
@klin2024 klin2024 requested a review from causten as a code owner April 21, 2026 01:35
Copilot AI review requested due to automatic review settings April 21, 2026 01:35

Copilot AI left a comment


Pull request overview

Extends the GPU channelwise-convolution prefuse pattern so FP16 (half) convolutions can be recognized and rewritten to gpu::channelwise_conv, matching the existing FP32 behavior.

Changes:

  • Allow find_channelwise_convolution in gpu::prefuse_ops to match shape::half_type inputs in addition to shape::float_type.


Comment thread: src/targets/gpu/prefuse_ops.cpp (Outdated)
Comment on lines +304 to +305

    if(input->get_shape().type() != shape::float_type and
       input->get_shape().type() != shape::half_type)

Copilot AI Apr 21, 2026


This condition calls input->get_shape().type() twice and hard-codes two comparisons. For readability and to avoid repeating work, consider caching the type in a local (e.g., auto t = …) and checking membership (e.g., via contains({shape::float_type, shape::half_type}, t)) as done in other GPU type gates (e.g., src/targets/gpu/lowering.cpp:349).

@causten causten requested a review from pfultz2 April 21, 2026 01:58
@codecov

codecov Bot commented Apr 21, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4808      +/-   ##
===========================================
+ Coverage    92.46%   92.52%   +0.05%     
===========================================
  Files          583      583              
  Lines        29551    29967     +416     
===========================================
+ Hits         27323    27724     +401     
- Misses        2228     2243      +15     

see 24 files with indirect coverage changes


@causten
Collaborator

causten commented Apr 21, 2026

Test Batch New Rate (85b943) Old Rate (008ed1)* Diff Status
torchvision-resnet50 64 3,175.93 3,160.43 0.49%
torchvision-resnet50_fp16 64 6,683.77 6,675.96 0.12%
torchvision-densenet121 32 2,688.38 2,689.17 -0.03%
torchvision-densenet121_fp16 32 4,530.57 4,521.82 0.19%
torchvision-inceptionv3 32 1,812.81 1,815.61 -0.15%
torchvision-inceptionv3_fp16 32 2,810.96 2,811.65 -0.02%
cadene-inceptionv4 16 834.98 835.09 -0.01%
cadene-resnext64x4 16 782.85 782.86 -0.00%
slim-mobilenet 64 8,383.13 8,379.55 0.04%
slim-nasnetalarge 64 229.72 199.87 14.93% 🔆
slim-resnet50v2 64 3,259.10 3,068.76 6.20% 🔆
bert-mrpc-onnx 8 1,132.66 1,132.91 -0.02%
bert-mrpc-tf 1 450.50 438.09 2.83%
pytorch-examples-wlang-gru 1 325.63 322.32 1.02%
pytorch-examples-wlang-lstm 1 446.90 450.39 -0.77%
torchvision-resnet50_1 1 740.40 641.47 15.42% 🔆
cadene-dpn92_1 1 435.25 376.55 15.59% 🔆
cadene-resnext101_1 1 368.88 130.30 183.09% 🔆
onnx-taau-downsample 1 400.92 196.35 104.19% 🔆
dlrm-criteoterabyte 1 33.12 33.05 0.22%
dlrm-criteoterabyte_fp16 1 52.17 31.15 67.50% 🔆
agentmodel 1 8,838.53 1,103.17 701.19% 🔆
unet_fp16 2 57.34 29.35 95.35% 🔆
resnet50v1_fp16 1 945.92 137.04 590.24% 🔆
resnet50v1_int8 1 914.95 189.33 383.25% 🔆
bert_base_cased_fp16 64 1,116.26 1,006.24 10.93% 🔆
bert_large_uncased_fp16 32 348.89 348.41 0.14%
bert_large_fp16 1 205.71 129.79 58.49% 🔆
distilgpt2_fp16 16 2,116.80 2,117.38 -0.03%
yolov5s 1 577.99 377.20 53.23% 🔆
tinyllama 1 46.20 46.21 -0.03%
vicuna-fastchat 1 44.18 44.19 -0.03%
whisper-tiny-encoder 1 419.13 419.10 0.01%
whisper-tiny-decoder 1 414.39 411.33 0.74%
llama2_7b 1 20.51 20.00 2.52%
qwen1.5-7b 1 23.78 10.66 122.97% 🔆
phi3-3.8b 1 26.93 26.93 -0.02%
llama3-8b 1 21.90 21.75 0.68%
whisper-large-encoder 1 10.29 10.30 -0.07%
whisper-large-decoder 1 102.94 22.49 357.66% 🔆
mistral-7b 1 24.02 17.06 40.80% 🔆
FLUX.1-schnell 1 743.96 753.65 -1.29%

Check flagged results 🔆

* No develop baseline was found for this PR's branch point; compared against the latest available develop run instead.

@causten
Collaborator

causten commented Apr 21, 2026

Test Status Result
bert-mrpc-onnx PASSED: MIGraphX meets tolerance
bert-mrpc-tf PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-gru PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-lstm PASSED: MIGraphX meets tolerance
dlrm-criteoterabyte PASSED: MIGraphX meets tolerance
agentmodel PASSED: MIGraphX meets tolerance
unet PASSED: MIGraphX meets tolerance
resnet50v1 PASSED: MIGraphX meets tolerance
bert_base_cased_fp16 PASSED: MIGraphX meets tolerance
bert_large_uncased_fp16 🔴 FAILED: MIGraphX is not within tolerance - check verbose output
bert_large PASSED: MIGraphX meets tolerance
yolov5s PASSED: MIGraphX meets tolerance
tinyllama PASSED: MIGraphX meets tolerance
vicuna-fastchat PASSED: MIGraphX meets tolerance
whisper-tiny-encoder PASSED: MIGraphX meets tolerance
whisper-tiny-decoder PASSED: MIGraphX meets tolerance
distilgpt2_fp16 PASSED: MIGraphX meets tolerance
llama2_7b PASSED: MIGraphX meets tolerance
qwen1.5-7b PASSED: MIGraphX meets tolerance
phi3-3.8b PASSED: MIGraphX meets tolerance
llama3-8b PASSED: MIGraphX meets tolerance
whisper-large-encoder ERROR - check error output
traceback:

    Traceback (most recent call last):
      File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 377, in <module>
        main()
      File "/src/AMDMIGraphX/tools/accuracy/accuracy_checker.py", line 224, in main
        model = migraphx.parse_onnx(model_name, default_dim_value=batch)
    RuntimeError: /data/src/include/migraphx/op/convolution.hpp:103: normalize_compute_shape: CONVOLUTION: mismatched channel numbers
whisper-large-decoder PASSED: MIGraphX meets tolerance
mistral-7b PASSED: MIGraphX meets tolerance
FLUX.1-schnell PASSED: MIGraphX meets tolerance

@klin2024
Contributor Author

We observed that #4806 and #4804 have the same failure. We will wait a few days and rebase this PR again.

@pfultz2
Collaborator

pfultz2 commented Apr 25, 2026

On Navi48, fp16 is sometimes slower:

Input Weight Count Baseline (ms) Time (ms) Change
1x3x22x22 3x1x7x7 10 0.0029 0.0035 -17.1%
1x3x24x24 3x1x13x13 6 0.0032 0.0043 -25.6%
1x3x38x38 3x1x7x7 4 0.0031 0.0037 -16.2%
1x3x512x514 3x1x1x3 4 0.0088 0.0055 +60.0%
1x3x514x512 3x1x3x1 4 0.0088 0.0054 +63.0%
1x3x256x258 3x1x1x3 4 0.0043 0.0035 +22.9%
1x3x258x256 3x1x3x1 4 0.0043 0.0036 +19.4%
1x3x128x130 3x1x1x3 4 0.0032 0.0031 +3.2%
1x3x130x128 3x1x3x1 4 0.0031 0.0031 +0.0%
1x3x64x66 3x1x1x3 4 0.0034 0.0035 -2.9%
1x3x66x64 3x1x3x1 4 0.0030 0.0036 -16.7%
1x3x32x34 3x1x1x3 4 0.0035 0.0030 +16.7%
1x3x34x32 3x1x3x1 4 0.0029 0.0030 -3.3%
1x3x97x97 3x1x7x7 1 0.0042 0.0040 +5.0%
1x3x96x96 3x1x7x7 1 0.0040 0.0040 +0.0%

@weizhu12-amd weizhu12-amd force-pushed the enable_fp16_channelwise branch from ef2fc84 to 85b943f Compare April 27, 2026 08:32
Allow channelwise convolution prefuse pass to match half_type
in addition to float_type. For small shapes (CWH < 48k), still
use mlir_conv.